#!/usr/bin/env python3

"""
Filter that injects attributes into a JSON object.

Existing values are overwritten.

"""

from __future__ import annotations

import argparse
import hashlib
import json
import sys
from collections.abc import Iterable
from fnmatch import fnmatchcase
from pathlib import Path
from typing import Any

__author__ = 'Murray Andrews'

# Program name shown in usage and error messages: the invoked script's file name stem.
PROG = Path(sys.argv[0]).stem

# First field of the injected hash string; identifies the layout of that string.
HASH_FORMAT_VERSION = 1
# Default for --hash-algorithm and for dict_hash().
HASH_ALGORITHM = 'sha256'
# A tag goes in the hash field to identify the source.
TAG_DEFAULT = 'lava'


# ------------------------------------------------------------------------------
class StoreNameValuePair(argparse.Action):
    """
    Store argparse values from options of the form --option name=value.

    The destination (self.dest) will be created as a dict {name: value}. This
    allows multiple name-value pairs to be set for the same option. If a name
    is repeated, the last value wins. Only the first `=` separates the name
    from the value, so values may themselves contain `=`.

    Usage is:

        argparser.add_argument('-x', metavar='key=value', action=StoreNameValuePair)

    or
        argparser.add_argument('-x', metavar='key=value ...', action=StoreNameValuePair,
                               nargs='+')

    """

    # --------------------------------------------------------------------------
    def __call__(self, parser, namespace, values, option_string=None):
        """
        Handle name=value option.

        :param parser:          The argument parser (unused).
        :param namespace:       The namespace receiving the parsed values.
        :param values:          A single `name=value` string or a list of them
                                (when `nargs` is used).
        :param option_string:   The option string that triggered the action
                                (unused).

        :raise argparse.ArgumentError: If a value contains no `=`.
        """

        # Work on a copy. argparse seeds the namespace with the very object
        # given as `default`, so updating in place would leak parsed values
        # into the shared default (and into any later parse).
        argdict = dict(getattr(namespace, self.dest, None) or {})

        if not isinstance(values, list):
            values = [values]
        for val in values:
            name, sep, value = val.partition('=')
            if not sep:
                raise argparse.ArgumentError(self, f'expected name=value but got {val!r}')
            argdict[name] = value

        setattr(namespace, self.dest, argdict)


# ------------------------------------------------------------------------------
def dot_dict_set(d: dict[str, Any], key: str, value: Any) -> dict[str, Any]:
    """
    Set a dict element based on a hierarchical dot separated key.

    Parent components of the key that are missing from the dict are created
    as empty dicts. A parent component that exists but is not a dict cannot
    hold the child key, so it is replaced by a new dict. An existing value at
    the final component is overwritten.

    :param d:       The dict to be modified (in place).
    :param key:     The compound key in the form `a.b.c...`.
    :param value:   The value to set.

    :return:        The dict (the same object that was passed in).
    """

    # str.split() always yields at least one element so `leaf` is always bound.
    *parents, leaf = key.split('.')

    node = d
    for k in parents:
        child = node.get(k)
        if not isinstance(child, dict):
            # Missing, or a scalar/list in the way: start a fresh object here.
            child = node[k] = {}
        node = child

    node[leaf] = value
    return d


# ------------------------------------------------------------------------------
def glob_strip(names: Iterable[str], patterns: str | Iterable[str]) -> set[str]:
    """
    Filter out the strings that match at least one glob pattern.

    Matching is glob style and case sensitive.

    Duplicates and ordering of the input are not preserved because the result
    is a set.

    :param names:       An iterable of candidate strings.
    :param patterns:    A single glob pattern, or an iterable of glob patterns.

    :return:            The set of candidates that match none of the patterns.
    """

    candidates = set(names)

    # A lone string is one pattern, not an iterable of one-character patterns.
    # Materialise the patterns so a one-shot iterator can be reused per name.
    globs = [patterns] if isinstance(patterns, str) else list(patterns)

    return {
        candidate
        for candidate in candidates
        if not any(fnmatchcase(candidate, glob) for glob in globs)
    }


# ------------------------------------------------------------------------------
def dict_hash(
    d: dict,
    ignore: str | Iterable[str] | None = None,
    algorithm: str = HASH_ALGORITHM,
) -> str:
    """
    Calculate an ASCII safe hash on a dictionary.

    The hash is taken over the JSON serialisation (with sorted keys) of the
    top-level entries whose keys match none of the `ignore` patterns.

    .. warning::
        This is not cryptographically secure and should not be used for any
        security related purpose. It's only for change detection.

    :param d:       The dictionary. It must be JSON serialisable.
    :param ignore:  Ignore any keys that match the specified glob pattern
                    or list of patterns. None or empty means hash every key.
    :param algorithm: Hashing algorithm. Must be one of the values supported
                    by `hashlib.new()`. Variable-length algorithms (the
                    `shake_*` family) produce a 32 byte digest.
    :return:        An ASCII safe hash (lower case hex digits).
    """

    # Digest length, in bytes, requested from extendable-output algorithms.
    xof_digest_bytes = 32

    # glob_strip() copes with a single pattern or an iterable of patterns.
    keys_to_hash = glob_strip(d, ignore or ())
    dict_to_hash = {k: d[k] for k in keys_to_hash}
    # sort_keys makes the serialisation independent of key order.
    data = json.dumps(dict_to_hash, sort_keys=True).encode('utf-8')

    digest = hashlib.new(algorithm, data)
    if digest.digest_size == 0:
        # shake_128 / shake_256 have no intrinsic digest size and their
        # hexdigest() requires an explicit length.
        return digest.hexdigest(xof_digest_bytes)
    return digest.hexdigest()


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Process the command line arguments.

    The returned namespace has these attributes: `hash_name`, `hash_algorithm`,
    `hash_ignore`, `hash_tag`, `deep_param`, `json_param`, `param` (the three
    param attributes are dicts of name to string value), `sort_keys` and
    `strict`.

    :return:    The args namespace.
    """

    argp = argparse.ArgumentParser(
        prog=PROG,
        description='Inject/replace values in a JSON object.',
        epilog=(
            'Data is read from stdin and written to stdout. '
            ' Processing of parameter settings is done in this order:'
            ' (1) Deep params (-d),'
            ' (2) JSON params (-j),'
            ' (3) Simple params (-p).'
        ),
    )

    hashp = argp.add_argument_group('hash injection arguments')
    hashp.add_argument(
        '--hash-name',
        action='store',
        metavar='KEY-NAME',
        help=(
            'Calculate a hash / checksum for the JSON object and insert it into'
            ' the object with the given key name.'
        ),
    )

    hashp.add_argument(
        '--hash-algorithm',
        action='store',
        # algorithms_guaranteed is a set; sort it so that usage and error
        # messages list the choices in the same order on every run.
        choices=sorted(hashlib.algorithms_guaranteed),
        default=HASH_ALGORITHM,
        help=f'Hash algorithm to use for the --hash-name option. Default is {HASH_ALGORITHM}.',
    )

    hashp.add_argument(
        '--hash-ignore',
        nargs='*',
        metavar='PATTERN',
        help='When calculating a hash, ignore keys that match the specified glob style patterns.',
    )

    # Belongs with the other hash options so --help lists them together.
    hashp.add_argument(
        '--hash-tag',
        action='store',
        default=TAG_DEFAULT,
        help=f'Add the specified tag into the hash field. Default is {TAG_DEFAULT}.',
    )

    paramp = argp.add_argument_group('parameter injection arguments')

    paramp.add_argument(
        '-d',
        '--deep-param',
        action=StoreNameValuePair,
        metavar='name.name...=value',
        default={},
        help=(
            'Set the value of an element in the JSON object specified by a dot'
            ' separated path. The value is injected as a string element. Objects'
            ' will be created at each level of the hierarchy, as required,'
            ' overwriting existing values.'
            ' Can be used multiple times.'
        ),
    )

    paramp.add_argument(
        '-j',
        '--json-param',
        action=StoreNameValuePair,
        metavar='name.name...=JSON',
        default={},
        help=(
            'Set the value of an element in the JSON object specified by a dot'
            ' separated path. The value must be valid JSON and will be decoded'
            ' prior to injection. Objects will be created at each level of the'
            ' hierarchy, as required, overwriting existing values.'
            ' Can be used multiple times.'
        ),
    )

    paramp.add_argument(
        '-p',
        '--param',
        action=StoreNameValuePair,
        metavar='name=value',
        default={},
        help=(
            'Set the value of a top-level element in the JSON object. The'
            ' parameter name is used literally.'
            ' Can be used multiple times.'
        ),
    )

    # store_false: sort_keys defaults to True and -S turns sorting off.
    argp.add_argument(
        '-S',
        '--no-sort',
        dest='sort_keys',
        action='store_false',
        help='Do not sort keys in output JSON.',
    )

    argp.add_argument(
        '--strict',
        action='store_true',
        help=(
            ' Force an error exit if the input is valid JSON but not an object.'
            ' By default, such inputs are passed through, possibly reformatted,'
            ' but otherwise unmodified.'
        ),
    )

    return argp.parse_args()


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Read JSON from stdin, inject the requested values and write it to stdout.

    Injection only happens when the input is a JSON object. Values are applied
    in this order: deep params, JSON params, simple params and finally the
    hash field (when --hash-name is given). Any other valid JSON input is
    written back unmodified unless --strict is set.

    :return:    status (0 on success).

    :raise ValueError:  If --strict is set and the input is not a JSON object.
    :raise Exception:   If a parameter cannot be decoded or set.
    """

    args = process_cli_args()

    obj = json.load(sys.stdin)
    if args.strict and not isinstance(obj, dict):
        raise ValueError('Input must be a JSON object (dictionary)')

    if isinstance(obj, dict):
        # Path based params with string values
        for key, value in args.deep_param.items():
            try:
                obj = dot_dict_set(obj, key, value)
            except Exception as e:
                raise Exception(f'Cannot set {key}: {e}') from e

        # Path based params with JSON values
        for key, value in args.json_param.items():
            try:
                v = json.loads(value)
            except Exception as e:
                raise Exception(f'Bad JSON value for {key}: {e}') from e
            try:
                obj = dot_dict_set(obj, key, v)
            except Exception as e:
                raise Exception(f'Cannot set {key}: {e}') from e

        # Simple params
        obj.update(args.param)

        if args.hash_name:
            # Leave any pre-existing hash field out of the calculation. It is
            # about to be replaced, and including it would make the result
            # depend on the old hash, so re-processing previously hashed
            # output would yield a different, unreproducible value.
            hash_input = {k: v for k, v in obj.items() if k != args.hash_name}
            hash_val = dict_hash(
                hash_input, ignore=args.hash_ignore, algorithm=args.hash_algorithm
            )
            obj[args.hash_name] = (
                f'{HASH_FORMAT_VERSION};{args.hash_tag};{args.hash_algorithm};{hash_val}'
            )

    json.dump(obj, sys.stdout, sort_keys=args.sort_keys, indent=4)
    # json.dump() does not terminate the output with a newline.
    print()

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # sys.exit() is used rather than the exit() builtin because the latter is
    # installed by the `site` module and is missing under `python -S` and in
    # some frozen / embedded interpreters.

    # Uncomment for debugging
    # sys.exit(main())  # noqa: ERA001
    try:
        sys.exit(main())
    except Exception as ex:
        # Report failures as a one line message instead of a traceback.
        print(f'{PROG}: {ex}', file=sys.stderr)
        sys.exit(1)
